/*    Copyright 2010 Tobias Marschall
 *
 *    This file is part of MoSDi.
 *
 *    MoSDi is free software: you can redistribute it and/or modify
 *    it under the terms of the GNU General Public License as published by
 *    the Free Software Foundation, either version 3 of the License, or
 *    (at your option) any later version.
 *
 *    MoSDi is distributed in the hope that it will be useful,
 *    but WITHOUT ANY WARRANTY; without even the implied warranty of
 *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *    GNU General Public License for more details.
 *
 *    You should have received a copy of the GNU General Public License
 *    along with MoSDi.  If not, see <http://www.gnu.org/licenses/>.
 */

package mosdi.paa;

import java.util.Arrays;

import mosdi.util.ArrayUtils;

/** Abstract base class for a probabilistic arithmetic automaton as defined in
 *  <a href="http://arxiv.org/abs/1011.5778">this arXiv article</a>.
 *  There are (at least) two ways of using the PAA framework for a new application:
 *  <p>
 *  1) Create a class derived from {@link PAA}, implementing all abstract methods.
 *  <p>
 *  2) Create a derived class from {@link DAA}, implementing all abstract methods, and 
 *  combine it with a {@link mosdi.fa.FiniteMemoryTextModel} to obtain a {@link TextBasedPAA}.
 *  This approach is preferable when the modelled problem can be seen as computing
 *  the distribution of the result of a deterministic computation performed on a random sequence. 
 *  <p>
 *  NOTE: To speed up calculations, this class lazily pre-computes and caches
 *  the outgoing transitions, transition probabilities, emissions and emission
 *  probabilities of all states by calling the implementations provided by deriving
 *  classes. Arithmetic operations are not cached: {@link PAA#performOperation(int, int, int)}
 *  is invoked directly whenever a value has to be updated. If the caching behaviour is
 *  not desired, the methods
 *  {@link PAA#getTargets(int)}, {@link PAA#getTargetProbabilities(int)}, {@link PAA#getEmissions(int)}
 *  and {@link PAA#getEmissionProbabilities(int)} must be overridden.
 */
// TODO: results of performOperation are not cached by this class; consider adding such a cache.
public abstract class PAA {
	/** Lazily built table of target states of every state, as returned by getTargets(). */
	private int[][] targets;
	/** Lazily built table of transition probabilities of every state, as returned
	 *  by getTargetProbabilities(). */
	private double[][] targetProbabilities;
	/** Lazily built table of possible emissions of every state, as returned by getEmissions(). */
	private int[][] emissions;
	/** Lazily built table of emission probabilities of every state, as returned
	 *  by getEmissionProbabilities(). */
	private double[][] emissionProbabilities;

	/** Returns the size of the state set. */
	public abstract int getStateCount();
	/** Returns the size of the value set. */
	public abstract int getValueCount();
	/** Returns the size of the emission set, that is, the number of different possible emissions. */
	public abstract int getEmissionCount();

	/** Returns the start value, that is, the PAA's value at time 0. */
	public abstract int getStartValue();
	/** Returns the start state, that is, the PAA's state at time 0. */
	public abstract int getStartState();

	/** Returns the probability of going from given <code>state</code> to given <code>targetState</code>. */
	public abstract double transitionProbability(int state, int targetState);

	/** Returns the probability that the given <code>state</code> produces the given <code>emission</code>. */
	public abstract double emissionProbability(int state, int emission);

	/** Returns the value that results from applying the operation associated with the given
	 *  <code>state</code> to the given <code>value</code> and the given <code>emission</code>. */
	public abstract int performOperation(int state, int value, int emission);

	/** Returns the operation if it can be cast into a {@link SimpleOperation}, otherwise returns null.
	 *  The default implementation returns null; a non-null operation is required by
	 *  {@link PAA#stateValueDistributionViaDoubling(long)}. */
	public SimpleOperation getOperation() { return null; }

	/** Returns an array of targets of a given state.
	 *  A state j is contained in getTargets(i) iff transitionProbability(i,j)>0.0.
	 *  On the first call, the targets of all states are computed and cached.
	 */
	protected int[] getTargets(int state) {
		if (targets==null) {
			int stateCount = getStateCount();
			int[][] table = new int[stateCount][];
			for (int i=0; i<stateCount; ++i) {
				// at most stateCount targets; trimmed to the actual number below
				int[] buffer = new int[stateCount];
				int n = 0;
				for (int j=0; j<stateCount; ++j) {
					if (transitionProbability(i,j)>0.0) buffer[n++] = j;
				}
				table[i] = Arrays.copyOf(buffer, n);
			}
			// publish only the fully built table, so that an exception during
			// construction does not leave a partially filled cache behind
			targets = table;
		}
		return targets[state];
	}

	/** Returns array of transition probabilities for outgoing transitions from the
	 *  given <code>state</code>. The returned array is consistent with the result
	 *  of {@link PAA#getTargets(int)}, i.e. <code>getTargetProbabilities(q)[i]</code>
	 *  is the probability of going from state <code>q</code> to state <code>getTargets(q)[i]</code>.
	 *  On the first call, the probabilities of all states are computed and cached.
	 */
	protected double[] getTargetProbabilities(int state) {
		if (targetProbabilities==null) {
			int stateCount = getStateCount();
			double[][] table = new double[stateCount][];
			for (int i=0; i<stateCount; ++i) {
				int[] t = getTargets(i);
				table[i] = new double[t.length];
				for (int k=0; k<t.length; ++k) table[i][k] = transitionProbability(i, t[k]);
			}
			targetProbabilities = table;
		}
		return targetProbabilities[state];
	}

	/** Returns an array of possible emissions of the given <code>state</code>, that is,
	 *  the value <code>z</code> is contained in <code>getEmissions(q)</code> iff
	 *  <code>emissionProbability(q,z)>0.0</code>.
	 *  On the first call, the emissions of all states are computed and cached.
	 */
	protected int[] getEmissions(int state) {
		if (emissions==null) {
			int stateCount = getStateCount();
			int emissionCount = getEmissionCount();
			int[][] table = new int[stateCount][];
			for (int i=0; i<stateCount; ++i) {
				// a state can have up to emissionCount (not stateCount) possible emissions
				int[] buffer = new int[emissionCount];
				int n = 0;
				for (int e=0; e<emissionCount; ++e) {
					if (emissionProbability(i,e)>0.0) buffer[n++] = e;
				}
				table[i] = Arrays.copyOf(buffer, n);
			}
			emissions = table;
		}
		return emissions[state];
	}

	/** Returns array of emission probabilities of given state.
	 *  The returned array is consistent with the result
	 *  of getEmissions(), i.e. getEmissionProbabilities(q)[i] is the
	 *  probability of the emission e = getEmissions(q)[i].
	 *  On the first call, the probabilities of all states are computed and cached.
	 */
	protected double[] getEmissionProbabilities(int state) {
		if (emissionProbabilities==null) {
			int stateCount = getStateCount();
			double[][] table = new double[stateCount][];
			for (int i=0; i<stateCount; ++i) {
				int[] e = getEmissions(i);
				table[i] = new double[e.length];
				for (int k=0; k<e.length; ++k) table[i][k] = emissionProbability(i, e[k]);
			}
			emissionProbabilities = table;
		}
		return emissionProbabilities[state];
	}

	/** Returns the joint distribution of state and value at time 0, that is,
	 *  before any state transition has been done. By default this is the
	 *  Dirac distribution assigning probability one to start state and
	 *  start value. This method might be overridden by deriving classes to
	 *  change this behaviour, e.g. to start from an equilibrium distribution.
	 *  @return A fresh array indexed as <code>[state][value]</code>.
	 */
	public double[][] stateValueStartDistribution() {
		double[][] result = new double[getStateCount()][getValueCount()];
		result[getStartState()][getStartValue()] = 1.0;
		return result;
	}

	/** Performs one time step.
	 *  @param distribution Joint state-value distribution before the time step,
	 *         indexed as <code>[state][value]</code>.
	 *  @return Joint state-value distribution after the time step.
	 */
	public double[][] updateStateValueDistribution(double[][] distribution) {
		int stateCount = getStateCount();
		int valueCount = getValueCount();
		double[][] result = new double[stateCount][valueCount];
		// if emissions are deterministic, the sum over emissions collapses to a single term
		DeterministicEmitter emitter = (this instanceof DeterministicEmitter) ? (DeterministicEmitter)this : null;
		// iterate over all source states
		for (int sourceState=0; sourceState<stateCount; ++sourceState) {
			int[] targetStates = getTargets(sourceState);
			double[] probs = getTargetProbabilities(sourceState);
			// iterate over all target states that can be reached from source state
			for (int i=0; i<targetStates.length; ++i) {
				int targetState = targetStates[i];
				double transitionProb = probs[i];
				if (emitter!=null) {
					int emission = emitter.getEmission(targetState);
					// iterate over all possible former values
					for (int value=0; value<valueCount; ++value) {
						result[targetState][performOperation(targetState, value, emission)]+=
							distribution[sourceState][value] * transitionProb;
					}
				} else {
					int[] stateEmissions = getEmissions(targetState);
					double[] emissionProbs = getEmissionProbabilities(targetState);
					// iterate over all emissions that can be emitted by target state
					for (int j=0; j<stateEmissions.length; ++j) {
						int emission = stateEmissions[j];
						double emissionProb = emissionProbs[j];
						// iterate over all possible former values
						for (int value=0; value<valueCount; ++value) {
							result[targetState][performOperation(targetState, value, emission)]+=
								distribution[sourceState][value] * transitionProb * emissionProb;
						}
					}
				}
			}
		}
		return result;
	}

	/** Performs one time step, just considering states, ignoring values.
	 *  @param distribution State distribution before time step.
	 *  @return State distribution after time step.
	 */
	public double[] updateStateDistribution(double[] distribution) {
		double[] result = new double[getStateCount()];
		// iterate over all source states
		for (int sourceState=0; sourceState<result.length; ++sourceState) {
			int[] targetStates = getTargets(sourceState);
			double[] probs = getTargetProbabilities(sourceState);
			// iterate over all target states that can be reached from source state
			for (int i=0; i<targetStates.length; ++i) {
				result[targetStates[i]] += distribution[sourceState] * probs[i];
			}
		}
		return result;
	}

	/** Helper method needed for the doubling technique. Combines two
	 *  "segment" distributions <code>dist[fromState][toState][valueChange]</code>
	 *  that are traversed one after the other; the value changes are combined
	 *  via <code>getOperation().joinValues(...)</code>.
	 */
	private double[][][] joinDistributions(double[][][] dist0, double[][][] dist1) {
		int stateCount = getStateCount();
		int valueCount = getValueCount();
		SimpleOperation op = getOperation();
		double[][][] resultDist = new double[stateCount][stateCount][valueCount];
		for (int state0=0; state0<stateCount; ++state0) {
			for (int state1=0; state1<stateCount; ++state1) {
				for (int value1=0; value1<valueCount; ++value1) {
					double p0 = dist0[state0][state1][value1];
					// zero-probability prefixes contribute nothing
					if (p0==0.0) continue;
					for (int state2=0; state2<stateCount; ++state2) {
						for (int value2=0; value2<valueCount; ++value2) {
							resultDist[state0][state2][op.joinValues(value1, value2)] +=
								p0 * dist1[state1][state2][value2];
						}
					}
				}
			}
		}
		return resultDist;
	}

	/** Equivalent to {@link PAA#computeStateValueDistribution(long)} but has
	 *  different runtime characteristics. The runtime is logarithmic in the number
	 *  of <code>steps</code>, cubic in the number of states and quadratic in the size
	 *  of the value set.
	 *  <p>
	 *  Only the probability mass that {@link PAA#stateValueStartDistribution()}
	 *  assigns to the start value is propagated (for <code>steps>0</code>), so the
	 *  equivalence holds when the start distribution puts all its mass on the start value.
	 *  @throws IllegalStateException if {@link PAA#getOperation()} returns null or this
	 *          PAA is not a {@link DeterministicEmitter}.
	 *  @throws IllegalArgumentException if <code>steps</code> is negative.
	 */
	public double[][] stateValueDistributionViaDoubling(long steps) {
		if ((getOperation()==null) || !(this instanceof DeterministicEmitter)) {
			throw new IllegalStateException("Not yet implemented.");
		}
		if (steps<0) {
			throw new IllegalArgumentException("Number of steps must be non-negative, got " + steps);
		}
		// no step performed: same result as computeStateValueDistribution(0)
		if (steps==0) return stateValueStartDistribution();
		DeterministicEmitter emitter = (DeterministicEmitter)this;
		int stateCount = getStateCount();
		int valueCount = getValueCount();
		int startValue = getStartValue();
		// exact number of binary digits of steps (steps>=1 here)
		int bitsNeeded = Long.SIZE - Long.numberOfLeadingZeros(steps);
		// distributions dist[k][state0][state1][value] is the probability of
		// going from state0 to state1 in 2^k steps and having changed the start
		// value into value.
		double[][][][] dists = new double[bitsNeeded][][][];

		// initialize first distribution (a single step)
		dists[0] = new double[stateCount][stateCount][valueCount];
		for (int sourceState=0; sourceState<stateCount; ++sourceState) {
			int[] targetStates = getTargets(sourceState);
			double[] targetProbs = getTargetProbabilities(sourceState);
			for (int i=0; i<targetStates.length; ++i) {
				int targetState = targetStates[i];
				int value = performOperation(targetState, startValue, emitter.getEmission(targetState));
				dists[0][sourceState][targetState][value] = targetProbs[i];
			}
		}
		// create all needed distributions by repeated squaring
		for (int bit=1; bit<bitsNeeded; ++bit) {
			dists[bit] = joinDistributions(dists[bit-1], dists[bit-1]);
			// drop the previous distribution unless its bit is set in steps
			if (((1L<<(bit-1))&steps)==0) dists[bit-1] = null;
		}
		// join the distributions of all set bits to create the final one
		double[][][] targetDist = null;
		for (double[][][] d : dists) {
			if (d==null) continue;
			targetDist = (targetDist==null) ? d : joinDistributions(targetDist, d);
		}
		double[][] startDist = stateValueStartDistribution();
		double[][] result = new double[stateCount][valueCount];
		for (int startState=0; startState<stateCount; ++startState) {
			double p = startDist[startState][startValue];
			if (p==0.0) continue;
			for (int endState=0; endState<stateCount; ++endState) {
				for (int endValue=0; endValue<valueCount; ++endValue) {
					result[endState][endValue] += p*targetDist[startState][endState][endValue];
				}
			}
		}
		return result;
	}

	/** Return the joint distribution of PAA state and value after the
	 *  given number of <code>iterations</code>. At time 0, the joint
	 *  state-value distribution is assumed to be the one given by
	 *  {@link PAA#stateValueStartDistribution()}. A non-positive number of
	 *  iterations yields the start distribution. */
	public double[][] computeStateValueDistribution(long iterations) {
		double[][] distribution = stateValueStartDistribution();
		// iterate over all times
		for (long t=0; t<iterations; ++t) {
			distribution = updateStateValueDistribution(distribution);
		}
		return distribution;
	}

	/** Marginalizes the given joint distribution of states and values to obtain a
	 *  distribution of values. */
	public double[] toValueDistribution(double[][] stateValueDistribution) {
		double[] result = new double[getValueCount()];
		for (int value=0; value<result.length; ++value) {
			for (int state=0; state<getStateCount(); ++state) {
				result[value]+=stateValueDistribution[state][value];
			}
		}
		return result;
	}

	/** Computes the value distribution after the given number of <code>iterations</code> by
	 *  computing the joint state-value distribution and marginalizing it. At time 0,
	 *  the joint state-value distribution is assumed to be the one given by
	 *  {@link PAA#stateValueStartDistribution()}. */
	public double[] computeValueDistribution(int iterations) {
		return toValueDistribution(computeStateValueDistribution(iterations));
	}

	/** Computes the value distribution after the given number of <code>iterations</code> by
	 *  computing the joint state-value distribution using the doubling algorithm
	 *  (see {@link PAA#stateValueDistributionViaDoubling}) and marginalizing it. At time 0,
	 *  the joint state-value distribution is assumed to be the one given by
	 *  {@link PAA#stateValueStartDistribution()}. */
	public double[] computeValueDistributionViaDoubling(long iterations) {
		return toValueDistribution(stateValueDistributionViaDoubling(iterations));
	}

	/** Updates the joint distribution of state and value until convergence to the
	 *  given accuracy has been reached, i.e. until the relative change of every
	 *  entry that is non-zero before an update is at most <code>accuracy</code>.
	 *  @return Equilibrium distribution.
	 *  @throws NoConvergenceException if no convergence was reached within stepLimit updates.
	 */
	public double[][] convergeToStateValueEquilibrium(double accuracy, int stepLimit) {
		double[][] distribution = stateValueStartDistribution();
		int stateCount = getStateCount();
		int valueCount = getValueCount();
		double maxDiff = Double.POSITIVE_INFINITY;
		// iterate until converged; the step limit is the only bound on the number of updates
		for (int t=0; maxDiff>accuracy; ++t) {
			if (t==stepLimit) throw new NoConvergenceException();
			double[][] newDistribution = updateStateValueDistribution(distribution);
			// check if equilibrium is reached; 0/0 yields NaN, which never increases maxDiff
			maxDiff = 0.0;
			for (int state=0; state<stateCount; ++state) {
				for (int value=0; value<valueCount; ++value) {
					double diff = Math.abs(1.0 - newDistribution[state][value]/distribution[state][value]);
					if (diff>maxDiff) maxDiff=diff;
				}
			}
			distribution = newDistribution;
		}
		return distribution;
	}

	/** Updates the state distribution until convergence to the given accuracy has been
	 *  reached, i.e. until the relative change of every entry that is non-zero before an
	 *  update is at most <code>accuracy</code>. The initial state distribution is the
	 *  state marginal of {@link PAA#stateValueStartDistribution()}.
	 *  @return Equilibrium distribution.
	 *  @throws NoConvergenceException if no convergence was reached within stepLimit updates.
	 */
	public double[] convergeToStateEquilibrium(double accuracy, int stepLimit) {
		double[] distribution = new double[getStateCount()];
		int state = 0;
		for (double[] valueDist : stateValueStartDistribution()) {
			distribution[state++] = ArrayUtils.sum(valueDist);
		}
		double maxDiff = Double.POSITIVE_INFINITY;
		for (int t=0; maxDiff>accuracy; ++t) {
			if (t==stepLimit) throw new NoConvergenceException();
			double[] newDistribution = updateStateDistribution(distribution);
			// check if equilibrium is reached; 0/0 yields NaN, which never increases maxDiff
			maxDiff = 0.0;
			for (state=0; state<distribution.length; ++state) {
				double diff = Math.abs(1.0 - newDistribution[state]/distribution[state]);
				if (diff>maxDiff) maxDiff=diff;
			}
			distribution = newDistribution;
		}
		return distribution;
	}

	/** Returns the distribution of the waiting time for a given value. At time 0,
	 *  the joint state-value distribution is assumed to be the one given by
	 *  {@link PAA#stateValueStartDistribution()}. Entry <code>t</code> of the result is
	 *  the probability that <code>value</code> is attained for the first time at time
	 *  <code>t</code>: probability mass that has reached the value is removed before
	 *  the next update.
	 *
	 *  @param maxTime The returned distribution will have size <code>maxTime+1</code>.
	 *  @param value The value to be waited for.
	 */
	public double[] waitingTimeForValue(int maxTime, int value) {
		double[] result = new double[maxTime+1];
		double[][] distribution = stateValueStartDistribution();
		// iterate over all times
		for (int t=0; t<=maxTime; ++t) {
			for (int state=0; state<getStateCount(); ++state) {
				result[t] += distribution[state][value];
				distribution[state][value] = 0;
			}
			if (t==maxTime) break;
			distribution = updateStateValueDistribution(distribution);
		}
		return result;
	}
}
